import pandas as pd
import joblib
import sys, time, copy
import numpy as np
from prettytable import PrettyTable
import Models_Classifier as models
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn import metrics
from imblearn import over_sampling, under_sampling
import math
from collections import Counter

# Flow-record field names, ending with the three label fields 'class1',
# 'class2' and 'class'.  Nothing in the visible part of this file reads this
# list.
# NOTE(review): presumably the column names of the input CSV -- confirm.
# From inspection it matches `feature` below except that it lacks the
# A_/B_ tcpTmS, tcpTmER, tcpEcI, tcpUtm and tcpBtm entries.
ALL_features = ['duration', 'srcIP', 'srcPort', 'dstIP', 'dstPort', 'A_timeFirst','A_timeLast', 'A_duration', 'A_numHdrDesc', 'A_numHdrs','A_hdrDesc', 'A_l4Proto', 'A_macStat', 'A_macPairs','A_srcMac_dstMac_numP', 'A_dstPortClassN', 'A_dstPortClass','A_numPktsSnt', 'A_numPktsRcvd', 'A_numBytesSnt', 'A_numBytesRcvd','A_minPktSz', 'A_maxPktSz', 'A_avePktSize', 'A_stdPktSize','A_minIAT', 'A_maxIAT', 'A_aveIAT', 'A_stdIAT', 'A_pktps','A_bytps', 'A_pktAsm', 'A_bytAsm', 'A_tcpFStat', 'A_ipMindIPID','A_ipMaxdIPID', 'A_ipMinTTL', 'A_ipMaxTTL', 'A_ipTTLChg','A_ipTOS', 'A_ipFlags', 'A_ipOptCnt', 'A_ipOptCpCl_Num','A_ip6OptCntHH_D', 'A_ip6OptHH_D', 'A_tcpISeqN', 'A_tcpPSeqCnt','A_tcpSeqSntBytes', 'A_tcpSeqFaultCnt', 'A_tcpPAckCnt','A_tcpFlwLssAckRcvdBytes', 'A_tcpAckFaultCnt', 'A_tcpInitWinSz','A_tcpAveWinSz', 'A_tcpMinWinSz', 'A_tcpMaxWinSz','A_tcpWinSzDwnCnt', 'A_tcpWinSzUpCnt', 'A_tcpWinSzChgDirCnt','A_tcpWinSzThRt', 'A_tcpFlags', 'A_tcpAnomaly', 'A_tcpOptPktCnt','A_tcpOptCnt', 'A_tcpOptions', 'A_tcpMSS', 'A_tcpWS', 'A_tcpMPTBF','A_tcpMPF', 'A_tcpMPAID', 'A_tcpMPdssF', 'A_tcpSSASAATrip','A_tcpRTTAckTripMin', 'A_tcpRTTAckTripMax', 'A_tcpRTTAckTripAve','A_tcpRTTAckTripJitAve', 'A_tcpRTTSseqAA', 'A_tcpRTTAckJitAve','A_icmpStat', 'A_icmpTCcnt', 'A_icmpBFTypH_TypL_Code','A_icmpTmGtw', 'A_icmpEchoSuccRatio', 'A_icmpPFindex', 'A_connSip','A_connDip', 'A_connSipDip', 'A_connSipDprt', 'A_connF','A_nFpCnt', 'A_tCnt', 'A_Ps_Iat_Cnt_PsCnt_IatCnt', 'A_dsMinPl','A_dsMaxPl', 'A_dsMeanPl', 'A_dsLowQuartilePl', 'A_dsMedianPl','A_dsUppQuartilePl', 'A_dsIqdPl', 'A_dsModePl', 'A_dsRangePl','A_dsStdPl', 'A_dsRobStdPl', 'A_dsSkewPl', 'A_dsExcPl','A_dsMinIat', 'A_dsMaxIat', 'A_dsMeanIat', 'A_dsLowQuartileIat','A_dsMedianIat', 'A_dsUppQuartileIat', 'A_dsIqdIat', 'A_dsModeIat','A_dsRangeIat', 'A_dsStdIat', 'A_dsRobStdIat', 'A_dsSkewIat','A_dsExcIat', 'B_timeFirst', 'B_timeLast', 'B_duration','B_numHdrDesc', 'B_numHdrs', 'B_hdrDesc', 'B_l4Proto', 'B_macStat','B_macPairs', 'B_srcMac_dstMac_numP', 
'B_dstPortClassN','B_dstPortClass', 'B_numPktsSnt', 'B_numPktsRcvd', 'B_numBytesSnt','B_numBytesRcvd', 'B_minPktSz', 'B_maxPktSz', 'B_avePktSize','B_stdPktSize', 'B_minIAT', 'B_maxIAT', 'B_aveIAT', 'B_stdIAT','B_pktps', 'B_bytps', 'B_pktAsm', 'B_bytAsm', 'B_tcpFStat','B_ipMindIPID', 'B_ipMaxdIPID', 'B_ipMinTTL', 'B_ipMaxTTL','B_ipTTLChg', 'B_ipTOS', 'B_ipFlags', 'B_ipOptCnt','B_ipOptCpCl_Num', 'B_ip6OptCntHH_D', 'B_ip6OptHH_D', 'B_tcpISeqN','B_tcpPSeqCnt', 'B_tcpSeqSntBytes', 'B_tcpSeqFaultCnt','B_tcpPAckCnt', 'B_tcpFlwLssAckRcvdBytes', 'B_tcpAckFaultCnt','B_tcpInitWinSz', 'B_tcpAveWinSz', 'B_tcpMinWinSz','B_tcpMaxWinSz', 'B_tcpWinSzDwnCnt', 'B_tcpWinSzUpCnt','B_tcpWinSzChgDirCnt', 'B_tcpWinSzThRt', 'B_tcpFlags','B_tcpAnomaly', 'B_tcpOptPktCnt', 'B_tcpOptCnt', 'B_tcpOptions','B_tcpMSS', 'B_tcpWS', 'B_tcpMPTBF', 'B_tcpMPF', 'B_tcpMPAID','B_tcpMPdssF', 'B_tcpSSASAATrip', 'B_tcpRTTAckTripMin','B_tcpRTTAckTripMax', 'B_tcpRTTAckTripAve','B_tcpRTTAckTripJitAve', 'B_tcpRTTSseqAA', 'B_tcpRTTAckJitAve','B_icmpStat', 'B_icmpTCcnt', 'B_icmpBFTypH_TypL_Code','B_icmpTmGtw', 'B_icmpEchoSuccRatio', 'B_icmpPFindex', 'B_connSip','B_connDip', 'B_connSipDip', 'B_connSipDprt', 'B_connF','B_nFpCnt', 'B_tCnt', 'B_Ps_Iat_Cnt_PsCnt_IatCnt', 'B_dsMinPl','B_dsMaxPl', 'B_dsMeanPl', 'B_dsLowQuartilePl', 'B_dsMedianPl','B_dsUppQuartilePl', 'B_dsIqdPl', 'B_dsModePl', 'B_dsRangePl','B_dsStdPl', 'B_dsRobStdPl', 'B_dsSkewPl', 'B_dsExcPl','B_dsMinIat', 'B_dsMaxIat', 'B_dsMeanIat', 'B_dsLowQuartileIat','B_dsMedianIat', 'B_dsUppQuartileIat', 'B_dsIqdIat', 'B_dsModeIat','B_dsRangeIat', 'B_dsStdIat', 'B_dsRobStdIat', 'B_dsSkewIat','B_dsExcIat', '0A_PL', '0A_IAT', '1A_PL', '1A_IAT', '2A_PL','2A_IAT', '3A_PL', '3A_IAT', '4A_PL', '4A_IAT', '5A_PL', '5A_IAT','6A_PL', '6A_IAT', '7A_PL', '7A_IAT', '8A_PL', '8A_IAT', '9A_PL','9A_IAT', '10A_PL', '10A_IAT', '11A_PL', '11A_IAT', '12A_PL','12A_IAT', '13A_PL', '13A_IAT', '14A_PL', '14A_IAT', '15A_PL','15A_IAT', '16A_PL', '16A_IAT', '17A_PL', '17A_IAT', 
'18A_PL','18A_IAT', '19A_PL', '19A_IAT', '0B_PL', '0B_IAT', '1B_PL','1B_IAT', '2B_PL', '2B_IAT', '3B_PL', '3B_IAT', '4B_PL', '4B_IAT','5B_PL', '5B_IAT', '6B_PL', '6B_IAT', '7B_PL', '7B_IAT', '8B_PL','8B_IAT', '9B_PL', '9B_IAT', '10B_PL', '10B_IAT', '11B_PL','11B_IAT', '12B_PL', '12B_IAT', '13B_PL', '13B_IAT', '14B_PL','14B_IAT', '15B_PL', '15B_IAT', '16B_PL', '16B_IAT', '17B_PL','17B_IAT', '18B_PL', '18B_IAT', '19B_PL', '19B_IAT', 'class1','class2', 'class']

# Flow-record field names, ending with the three label fields 'class1',
# 'class2' and 'class'.  Nothing in the visible part of this file reads this
# list.
# NOTE(review): from inspection this is `ALL_features` plus the A_/B_
# tcpTmS, tcpTmER, tcpEcI, tcpUtm and tcpBtm entries -- confirm which of the
# two lists matches the CSV actually being loaded.
feature = ['duration', 'srcIP', 'srcPort', 'dstIP', 'dstPort', 'A_timeFirst', 'A_timeLast', 'A_duration', 'A_numHdrDesc', 'A_numHdrs', 'A_hdrDesc', 'A_l4Proto', 'A_macStat', 'A_macPairs', 'A_srcMac_dstMac_numP', 'A_dstPortClassN', 'A_dstPortClass', 'A_numPktsSnt', 'A_numPktsRcvd', 'A_numBytesSnt', 'A_numBytesRcvd', 'A_minPktSz', 'A_maxPktSz', 'A_avePktSize', 'A_stdPktSize', 'A_minIAT', 'A_maxIAT', 'A_aveIAT', 'A_stdIAT', 'A_pktps', 'A_bytps', 'A_pktAsm', 'A_bytAsm', 'A_tcpFStat', 'A_ipMindIPID', 'A_ipMaxdIPID', 'A_ipMinTTL', 'A_ipMaxTTL', 'A_ipTTLChg', 'A_ipTOS', 'A_ipFlags', 'A_ipOptCnt', 'A_ipOptCpCl_Num', 'A_ip6OptCntHH_D', 'A_ip6OptHH_D', 'A_tcpISeqN', 'A_tcpPSeqCnt', 'A_tcpSeqSntBytes', 'A_tcpSeqFaultCnt', 'A_tcpPAckCnt', 'A_tcpFlwLssAckRcvdBytes', 'A_tcpAckFaultCnt', 'A_tcpInitWinSz', 'A_tcpAveWinSz', 'A_tcpMinWinSz', 'A_tcpMaxWinSz', 'A_tcpWinSzDwnCnt', 'A_tcpWinSzUpCnt', 'A_tcpWinSzChgDirCnt', 'A_tcpWinSzThRt', 'A_tcpFlags', 'A_tcpAnomaly', 'A_tcpOptPktCnt', 'A_tcpOptCnt', 'A_tcpOptions', 'A_tcpMSS', 'A_tcpWS', 'A_tcpMPTBF', 'A_tcpMPF', 'A_tcpMPAID', 'A_tcpMPdssF', 'A_tcpTmS', 'A_tcpTmER', 'A_tcpEcI', 'A_tcpUtm', 'A_tcpBtm', 'A_tcpSSASAATrip', 'A_tcpRTTAckTripMin', 'A_tcpRTTAckTripMax', 'A_tcpRTTAckTripAve', 'A_tcpRTTAckTripJitAve', 'A_tcpRTTSseqAA', 'A_tcpRTTAckJitAve', 'A_icmpStat', 'A_icmpTCcnt', 'A_icmpBFTypH_TypL_Code', 'A_icmpTmGtw', 'A_icmpEchoSuccRatio', 'A_icmpPFindex', 'A_connSip', 'A_connDip', 'A_connSipDip', 'A_connSipDprt', 'A_connF', 'A_nFpCnt', 'A_tCnt', 'A_Ps_Iat_Cnt_PsCnt_IatCnt', 'A_dsMinPl', 'A_dsMaxPl', 'A_dsMeanPl', 'A_dsLowQuartilePl', 'A_dsMedianPl', 'A_dsUppQuartilePl', 'A_dsIqdPl', 'A_dsModePl', 'A_dsRangePl', 'A_dsStdPl', 'A_dsRobStdPl', 'A_dsSkewPl', 'A_dsExcPl', 'A_dsMinIat', 'A_dsMaxIat', 'A_dsMeanIat', 'A_dsLowQuartileIat', 'A_dsMedianIat', 'A_dsUppQuartileIat', 'A_dsIqdIat', 'A_dsModeIat', 'A_dsRangeIat', 'A_dsStdIat', 'A_dsRobStdIat', 'A_dsSkewIat', 'A_dsExcIat', 'B_timeFirst', 'B_timeLast', 'B_duration', 'B_numHdrDesc', 
'B_numHdrs', 'B_hdrDesc', 'B_l4Proto', 'B_macStat', 'B_macPairs', 'B_srcMac_dstMac_numP', 'B_dstPortClassN', 'B_dstPortClass', 'B_numPktsSnt', 'B_numPktsRcvd', 'B_numBytesSnt', 'B_numBytesRcvd', 'B_minPktSz', 'B_maxPktSz', 'B_avePktSize', 'B_stdPktSize', 'B_minIAT', 'B_maxIAT', 'B_aveIAT', 'B_stdIAT', 'B_pktps', 'B_bytps', 'B_pktAsm', 'B_bytAsm', 'B_tcpFStat', 'B_ipMindIPID', 'B_ipMaxdIPID', 'B_ipMinTTL', 'B_ipMaxTTL', 'B_ipTTLChg', 'B_ipTOS', 'B_ipFlags', 'B_ipOptCnt', 'B_ipOptCpCl_Num', 'B_ip6OptCntHH_D', 'B_ip6OptHH_D', 'B_tcpISeqN', 'B_tcpPSeqCnt', 'B_tcpSeqSntBytes', 'B_tcpSeqFaultCnt', 'B_tcpPAckCnt', 'B_tcpFlwLssAckRcvdBytes', 'B_tcpAckFaultCnt', 'B_tcpInitWinSz', 'B_tcpAveWinSz', 'B_tcpMinWinSz', 'B_tcpMaxWinSz', 'B_tcpWinSzDwnCnt', 'B_tcpWinSzUpCnt', 'B_tcpWinSzChgDirCnt', 'B_tcpWinSzThRt', 'B_tcpFlags', 'B_tcpAnomaly', 'B_tcpOptPktCnt', 'B_tcpOptCnt', 'B_tcpOptions', 'B_tcpMSS', 'B_tcpWS', 'B_tcpMPTBF', 'B_tcpMPF', 'B_tcpMPAID', 'B_tcpMPdssF', 'B_tcpTmS', 'B_tcpTmER', 'B_tcpEcI', 'B_tcpUtm', 'B_tcpBtm', 'B_tcpSSASAATrip', 'B_tcpRTTAckTripMin', 'B_tcpRTTAckTripMax', 'B_tcpRTTAckTripAve', 'B_tcpRTTAckTripJitAve', 'B_tcpRTTSseqAA', 'B_tcpRTTAckJitAve', 'B_icmpStat', 'B_icmpTCcnt', 'B_icmpBFTypH_TypL_Code', 'B_icmpTmGtw', 'B_icmpEchoSuccRatio', 'B_icmpPFindex', 'B_connSip', 'B_connDip', 'B_connSipDip', 'B_connSipDprt', 'B_connF', 'B_nFpCnt', 'B_tCnt', 'B_Ps_Iat_Cnt_PsCnt_IatCnt', 'B_dsMinPl', 'B_dsMaxPl', 'B_dsMeanPl', 'B_dsLowQuartilePl', 'B_dsMedianPl', 'B_dsUppQuartilePl', 'B_dsIqdPl', 'B_dsModePl', 'B_dsRangePl', 'B_dsStdPl', 'B_dsRobStdPl', 'B_dsSkewPl', 'B_dsExcPl', 'B_dsMinIat', 'B_dsMaxIat', 'B_dsMeanIat', 'B_dsLowQuartileIat', 'B_dsMedianIat', 'B_dsUppQuartileIat', 'B_dsIqdIat', 'B_dsModeIat', 'B_dsRangeIat', 'B_dsStdIat', 'B_dsRobStdIat', 'B_dsSkewIat', 'B_dsExcIat', '0A_PL', '0A_IAT', '1A_PL', '1A_IAT', '2A_PL', '2A_IAT', '3A_PL', '3A_IAT', '4A_PL', '4A_IAT', '5A_PL', '5A_IAT', '6A_PL', '6A_IAT', '7A_PL', '7A_IAT', '8A_PL', '8A_IAT', '9A_PL', 
'9A_IAT', '10A_PL', '10A_IAT', '11A_PL', '11A_IAT', '12A_PL', '12A_IAT', '13A_PL', '13A_IAT', '14A_PL', '14A_IAT', '15A_PL', '15A_IAT', '16A_PL', '16A_IAT', '17A_PL', '17A_IAT', '18A_PL', '18A_IAT', '19A_PL', '19A_IAT', '0B_PL', '0B_IAT', '1B_PL', '1B_IAT', '2B_PL', '2B_IAT', '3B_PL', '3B_IAT', '4B_PL', '4B_IAT', '5B_PL', '5B_IAT', '6B_PL', '6B_IAT', '7B_PL', '7B_IAT', '8B_PL', '8B_IAT', '9B_PL', '9B_IAT', '10B_PL', '10B_IAT', '11B_PL', '11B_IAT', '12B_PL', '12B_IAT', '13B_PL', '13B_IAT', '14B_PL', '14B_IAT', '15B_PL', '15B_IAT', '16B_PL', '16B_IAT', '17B_PL', '17B_IAT', '18B_PL', '18B_IAT', '19B_PL', '19B_IAT', 'class1', 'class2', 'class']

# A selection of A_/B_ field names followed by the single label field
# 'class'.  Nothing in the visible part of this file reads this list.
# NOTE(review): presumably the numeric-valued columns only -- confirm.
numfeature = ['A_numPktsSnt', 'A_numPktsRcvd', 'A_numBytesSnt', 'A_numBytesRcvd', 'A_minPktSz', 'A_maxPktSz', 'A_avePktSize', 'A_stdPktSize', 'A_minIAT', 'A_maxIAT', 'A_aveIAT', 'A_stdIAT', 'A_pktps', 'A_bytps', 'A_pktAsm', 'A_bytAsm',  'A_ipOptCnt', 'A_tcpISeqN', 'A_tcpPSeqCnt', 'A_tcpSeqSntBytes', 'A_tcpSeqFaultCnt', 'A_tcpPAckCnt', 'A_tcpFlwLssAckRcvdBytes', 'A_tcpAckFaultCnt', 'A_tcpInitWinSz', 'A_tcpAveWinSz', 'A_tcpMinWinSz', 'A_tcpMaxWinSz', 'A_tcpWinSzDwnCnt', 'A_tcpWinSzUpCnt', 'A_tcpWinSzChgDirCnt', 'A_tcpWinSzThRt', 'A_tcpOptPktCnt', 'A_tcpOptCnt', 'A_tcpMSS', 'A_tcpWS', 'A_tcpSSASAATrip', 'A_tcpRTTAckTripMin', 'A_tcpRTTAckTripMax', 'A_tcpRTTAckTripAve', 'A_tcpRTTAckTripJitAve', 'A_tcpRTTSseqAA', 'A_tcpRTTAckJitAve',  'A_connF', 'A_nFpCnt', 'A_tCnt', 'A_dsMinPl', 'A_dsMaxPl', 'A_dsMeanPl', 'A_dsLowQuartilePl', 'A_dsMedianPl', 'A_dsUppQuartilePl', 'A_dsIqdPl', 'A_dsModePl', 'A_dsRangePl', 'A_dsStdPl', 'A_dsRobStdPl', 'A_dsSkewPl', 'A_dsExcPl', 'A_dsMinIat', 'A_dsMaxIat', 'A_dsMeanIat', 'A_dsLowQuartileIat', 'A_dsMedianIat', 'A_dsUppQuartileIat', 'A_dsIqdIat', 'A_dsModeIat', 'A_dsRangeIat', 'A_dsStdIat', 'A_dsRobStdIat', 'A_dsSkewIat', 'A_dsExcIat','B_numPktsSnt', 'B_numPktsRcvd', 'B_numBytesSnt', 'B_numBytesRcvd', 'B_minPktSz', 'B_maxPktSz', 'B_avePktSize', 'B_stdPktSize', 'B_minIAT', 'B_maxIAT', 'B_aveIAT', 'B_stdIAT', 'B_pktps', 'B_bytps', 'B_pktAsm', 'B_bytAsm', 'B_ipOptCnt', 'B_tcpISeqN', 'B_tcpPSeqCnt', 'B_tcpSeqSntBytes', 'B_tcpSeqFaultCnt', 'B_tcpPAckCnt', 'B_tcpFlwLssAckRcvdBytes', 'B_tcpAckFaultCnt', 'B_tcpInitWinSz', 'B_tcpAveWinSz', 'B_tcpMinWinSz', 'B_tcpMaxWinSz', 'B_tcpWinSzDwnCnt', 'B_tcpWinSzUpCnt', 'B_tcpWinSzChgDirCnt', 'B_tcpWinSzThRt', 'B_tcpOptPktCnt', 'B_tcpOptCnt', 'B_tcpMSS', 'B_tcpWS', 'B_tcpSSASAATrip', 'B_tcpRTTAckTripMin', 'B_tcpRTTAckTripMax', 'B_tcpRTTAckTripAve', 'B_tcpRTTAckTripJitAve', 'B_tcpRTTSseqAA', 'B_tcpRTTAckJitAve', 'B_connF', 'B_nFpCnt', 'B_tCnt', 'B_dsMinPl', 'B_dsMaxPl', 'B_dsMeanPl', 'B_dsLowQuartilePl', 
'B_dsMedianPl', 'B_dsUppQuartilePl', 'B_dsIqdPl', 'B_dsModePl', 'B_dsRangePl', 'B_dsStdPl', 'B_dsRobStdPl', 'B_dsSkewPl', 'B_dsExcPl', 'B_dsMinIat', 'B_dsMaxIat', 'B_dsMeanIat', 'B_dsLowQuartileIat', 'B_dsMedianIat', 'B_dsUppQuartileIat', 'B_dsIqdIat', 'B_dsModeIat', 'B_dsRangeIat', 'B_dsStdIat', 'B_dsRobStdIat', 'B_dsSkewIat', 'B_dsExcIat', 'class']

# Names of the form '<n><side>_<metric>' for n = 0..19, side 'A' then 'B',
# metric 'PL' then 'IAT': all 'A' entries first (ordered by n, PL before
# IAT), followed by all 'B' entries in the same order -- 80 names in total.
# NOTE(review): PL / IAT presumably stand for packet length and
# inter-arrival time of the n-th packet -- confirm.
pl_iat = [
    '{}{}_{}'.format(packet_no, side, metric)
    for side in ('A', 'B')
    for packet_no in range(20)
    for metric in ('PL', 'IAT')
]

###  User parameters---------------------------------------------------------------------------------
# Path of the CSV file loaded with pd.read_csv when the script is run.
filename = r'D:\test.csv'
# Default `test_size` passed to train_test_split by get_train_test_set.
train_test_size = 0.1

# Choose the classification models.
# Each dict maps a display name to a zero-argument callable that returns an
# estimator; classify() calls it and then uses .fit() / .predict().
#modeldict = {'c45':models.c45, 'cart':models.cart,'knn':models.knn,'lrc':models.lrc,'rf10':models.rf10,'rf20':models.rf20,'rf30':models.rf30,'gbdt':models.gbdt,'AdaBoost':models.AdaBoost,'gnb':models.gnb,'lda':models.lda,'qda':models.qda,'svm':models.svm}
Online_model = {'c45':models.c45}   # first stage: 'tor' vs 'normal'
Offline_model = {'c45':models.c45}  # second stage: fine-grained label

# Treat 'normal' as a class of its own in the second stage.
normal_class = 1 # > 0: add 'normal' rows to the second-stage training set
# Number of 'normal' training rows drawn for the second stage.
mod2_nor_train_num = 5000

###  ----------------------------------------------------------------------------------------------
###  ------------------------- Do not modify anything below this line -------------------------------
###  ----------------------------------------------------------------------------------------------

def get_train_test_set(data, a=1, b=2, nor_tor_ratio=1, size=train_test_size):
    """Yield one resampled train/test split per normal-to-tor ratio.

    The ratio starts at ``a`` and grows by ``nor_tor_ratio`` after every
    split until it reaches ``b`` (exclusive).  For each ratio the 'normal'
    class is resampled to ``ceil(tor_num * ratio)`` rows, where ``tor_num``
    is the number of rows whose 'class1' is not 'normal'; every other class
    keeps its original count.  'normal' is over-sampled with BorderlineSMOTE
    when there are too few rows and randomly under-sampled otherwise.

    Args:
        data: DataFrame with a 'class1' column.  All columns but the last
            are used as features and the last column as the label.
        a: First normal-to-tor ratio.
        b: Exclusive upper bound of the ratio.
        nor_tor_ratio: Increment applied to the ratio after each split.
        size: ``test_size`` forwarded to ``train_test_split``.

    Yields:
        ``(x_train, x_test, y_train, y_test)`` as returned by
        ``train_test_split``.

    Raises:
        ValueError: If a split is requested (``a < b``) with a non-positive
            ``nor_tor_ratio``; the sweep would never terminate.  As with
            any generator, this is raised on the first ``next()``.
    """
    if not a < b:
        return
    if nor_tor_ratio <= 0:
        raise ValueError(
            'nor_tor_ratio must be positive, got %r' % (nor_tor_ratio,))

    # None of the following depends on the current ratio, so do it once
    # instead of on every iteration (data.values copies the whole frame).
    class_counts = Counter(data['class1'])
    tor_num = len(data[data['class1']!='normal'])
    nor = len(data) - tor_num
    dataset = data.values
    features = dataset[:, :-1]
    label = dataset[:, -1]
    # NOTE(review): the per-class targets are counted from 'class1' but the
    # sampler is fitted on the LAST column; imblearn expects the strategy
    # keys to be classes present in y, so both columns must carry the same
    # class names -- confirm against the CSV layout.  Likewise every column
    # except the last (including any other label columns) ends up in the
    # feature matrix.

    while a < b:
        nor_num = math.ceil(tor_num*a)

        # Target count per class: only 'normal' is changed.
        type_dict = class_counts.copy()
        type_dict['normal'] = nor_num
        if nor < nor_num:
            # Over-sample 'normal'.
            sampling = over_sampling.BorderlineSMOTE(kind='borderline-1',sampling_strategy=type_dict,random_state=42)
        else:
            # Under-sample 'normal'.
            sampling = under_sampling.RandomUnderSampler(sampling_strategy=type_dict,random_state=0)

        x, y = sampling.fit_resample(features, label)
        print(Counter(y))

        x_train,x_test,y_train,y_test = train_test_split(x,y,test_size=size)
        yield x_train,x_test,y_train,y_test
        a += nor_tor_ratio
        print('a:',a)


### Load data---------------------------------------------------------------------------------------
# Read the data
#data = pd.read_csv(filename ,low_memory=False, delimiter=',')
#data = org_data.replace(' ', 0)

### Build the dataset-------------------------------------------------------------------------------
# Select the features and the labels
def _to_binary_labels(labels):
    """Return a copy of `labels` with every non-'normal' entry set to 'tor'."""
    binary = copy.deepcopy(labels)
    binary[binary!='normal']='tor'
    return binary


def get_dataset(data, a=1, b=1, nor_tor_ratio=1):
    """Yield train/test splits together with binary tor/normal labels.

    Wraps ``get_train_test_set(data, a, b, nor_tor_ratio)`` and, for every
    split it produces, derives a second label array in which every label
    other than 'normal' is replaced by 'tor'.

    Note that with the default ``a == b`` the underlying sweep is empty and
    nothing is yielded; callers are expected to pass ``b > a``.

    Yields:
        ``(x_train, y_train, z_train, x_test, y_test, z_test)`` where ``x``
        are the features, ``y`` the original (traffic-type) labels and ``z``
        the binary 'tor' / 'normal' labels.
    """
    for x_train, x_test, y_train, y_test in get_train_test_set(data, a, b, nor_tor_ratio):
        z_train = _to_binary_labels(y_train)
        z_test = _to_binary_labels(y_test)
        yield x_train, y_train, z_train, x_test, y_test, z_test

### Classification----------------------------------------------------------------------------
# Two-stage classification: online stage (tor vs. normal), then offline stage (traffic type)
def _stage2_training_set(x_train, y_train, z_train):
    """Build the training set for the second-stage (offline) classifier.

    It always contains every training row whose binary label is 'tor'.
    When ``normal_class > 0`` up to ``mod2_nor_train_num`` randomly chosen
    'normal' rows are appended so that the second stage can also predict
    'normal'.
    """
    is_tor = z_train == 'tor'
    m_train = x_train[is_tor]
    n_train = y_train[is_tor]

    if normal_class > 0:
        # Draw from the row positions of the normal samples themselves, so
        # that the selected rows are guaranteed to be 'normal', and never
        # ask for more rows than exist (replace=False would raise).
        normal_rows = np.flatnonzero(z_train == 'normal')
        take = min(mod2_nor_train_num, normal_rows.size)
        if take:
            chosen = np.random.choice(normal_rows, take, replace=False)
            m_train = np.vstack((m_train, x_train[chosen]))
            n_train = np.append(n_train, y_train[chosen])

    return m_train, n_train


def classify(x_train, y_train, z_train, x_test, y_test, z_test):
    """Run the two-stage classification and tabulate the metrics.

    Stage 1 ("online"): every model in ``Online_model`` is trained on the
    binary labels ``z_train`` and predicts 'tor' / 'normal' for ``x_test``.
    Stage 2 ("offline"): every model in ``Offline_model`` is trained on the
    fine-grained labels (see ``_stage2_training_set``) and classifies only
    the test rows that stage 1 predicted as 'tor'.

    Args:
        x_train, x_test: Feature matrices (NumPy arrays).
        y_train, y_test: Fine-grained labels (NumPy arrays).
        z_train, z_test: Binary 'tor' / 'normal' labels (NumPy arrays).

    Returns:
        A PrettyTable with one row per (stage-1 model, stage-2 model) pair:
        row number, summed prediction time of both stages, end-to-end
        accuracy, stage-2 weighted precision (reported as 'all-precision'),
        ``rec_1 * rec_2`` (reported as 'all-recall'), followed by name,
        accuracy, precision and recall of each stage.
    """
    table = PrettyTable(['num','time','all-accuracy','all-precision','all-recall','model-1','acc-1','pre-1','rec-1','model-2','acc-2','pre-2','rec-2'])

    row_num = 1
    # Online stage.
    for model1, make_online in Online_model.items():
        print("start ",model1)
        clf_1 = make_online()
        clf_1.fit(x_train, z_train)
        # To persist the model: joblib.dump(clf_1, model1 + '.m')

        # perf_counter is monotonic, so it is the right clock for durations.
        start_1 = time.perf_counter()
        predict_1 = clf_1.predict(x_test)
        time_1 = time.perf_counter() - start_1

        # Stage-1 metrics, with 'tor' as the positive class.
        acc_1 = accuracy_score(z_test, predict_1)
        rec_1 = recall_score(z_test, predict_1, pos_label='tor')
        pre_1 = precision_score(z_test, predict_1, pos_label='tor')
        print(model1, " ", pre_1)

        # Stage-2 data: train on known tor rows (plus optional normal rows),
        # test on whatever stage 1 flagged as tor.
        m_train, n_train = _stage2_training_set(x_train, y_train, z_train)
        flagged = predict_1 == 'tor'
        m_test = x_test[flagged]
        n_test = y_test[flagged]

        # Offline stage.
        for model2, make_offline in Offline_model.items():
            clf_2 = make_offline()
            clf_2.fit(m_train, n_train)

            start_2 = time.perf_counter()
            predict_2 = clf_2.predict(m_test)
            time_2 = time.perf_counter() - start_2

            # Stage-2 metrics over the rows that reached stage 2.
            acc_2 = accuracy_score(n_test, predict_2)
            rec_2 = recall_score(n_test, predict_2, average='weighted')
            pre_2 = precision_score(n_test, predict_2, average='weighted')

            # End-to-end prediction: rows stage 1 called 'normal' stay
            # 'normal', the rest take the stage-2 label.
            final_pred = np.array(predict_1, dtype=object)
            final_pred[flagged] = predict_2
            acc_all = accuracy_score(y_test, final_pred)

            table.add_row([row_num, time_1 + time_2, acc_all, pre_2, rec_1*rec_2, model1, acc_1, pre_1, rec_1, model2, acc_2, pre_2, rec_2])
            row_num += 1

    return table

if __name__ == '__main__':
    # Load the flow records named by the `filename` user parameter.
    data = pd.read_csv(filename ,low_memory=False, delimiter=',')

    # Sweep the normal-to-tor ratio from 1 up to (but excluding) 1000 in
    # steps of 100 and append one result table per ratio to result.txt.
    for x_train, y_train, z_train, x_test, y_test, z_test in get_dataset(data, 1, 1000, 100):
        result = classify(x_train, y_train, z_train, x_test, y_test, z_test)
        with open('result.txt', 'a', encoding='utf-8') as f:
            f.write('\n')
            f.write('------------------------------------------------------------------------\n')
            f.write(str(result))

    sys.exit()
